{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "source": [
    "# Using Custom Ops with TF2ONNX\n",
    "\n",
    "The custom ops framework lets you define new ONNX operators in Python or C++ and load them into ORT.  This makes it possible to convert and run TF models with ops that have no current ONNX equivalent.  The framework also serves as a place for sharing custom op definitions.\n",
    "\n",
    "There are 3 main ways to use this framework:\n",
    "- Case 1: Converting a TF model using an existing custom op\n",
    "  - Best option if op is already implemented\n",
    "- Case 2: Defining new custom ops in Python to use in conversion\n",
    "  - Easier than C++ but perf might be poor\n",
    "- Case 3: Defining new custom ops in C++\n",
    "  - Likely better perf than Python but requires building the customops repo from source\n",
    "\n",
    "For cases 1 and 2, you can use the off-the-shelf pip package `onnxruntime_extensions`.  For case 3, you will need to clone and build the customops repo.  Follow the instructions [here](https://github.com/microsoft/ort-customops#getting-started).\n",
    "\n",
    "You will also need to install the onnxruntime, tensorflow, and tf2onnx packages.  **NOTE: tf2onnx version (FIXME) is required for this tutorial.**"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "## Case 1: Converting a TF model using an existing custom op\n",
    "\n",
    "First let's create a model that requires a custom op that is already defined in the custom ops framework"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model1(tf.keras.Model):\n",
    "\n",
    "    def __init__(self, name='model1', **kwargs):\n",
    "        super(Model1, self).__init__(name=name, **kwargs)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        return tf.strings.regex_replace(inputs, \" \", \"_\", replace_global=True)\n",
    "\n",
    "model1 = Model1()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(1,), dtype=string, numpy=array([b'Hello_world!'], dtype=object)>"
      ]
     },
     "metadata": {},
     "execution_count": 3
    }
   ],
   "source": [
    "model1(tf.constant([\"Hello world!\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "WARNING:tensorflow:From c:\\Users\\tomwi\\Documents\\onnxenvtf23\\lib\\site-packages\\tensorflow\\python\\training\\tracking\\tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
      "WARNING:tensorflow:From c:\\Users\\tomwi\\Documents\\onnxenvtf23\\lib\\site-packages\\tensorflow\\python\\training\\tracking\\tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
      "INFO:tensorflow:Assets written to: saved_model1\\assets\n"
     ]
    }
   ],
   "source": [
    "model1.save(\"saved_model1\")"
   ]
  },
  {
   "source": [
    "### Identifying unsupported ops\n",
    "\n",
    "If a model has unsupported ops, tf2onnx will still convert it, but the unsupported ops will be left in the graph unchanged. An error message will list the unsupported ops."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -m tf2onnx.convert --saved-model \"saved_model1\" --output \"model1.onnx\""
   ]
  },
  {
   "source": [
    "Loading a model with unsupported ops into ORT raises an error."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from model1.onnx failed:This is an invalid model. Error in Node:PartitionedCall/autoencoder/StaticRegexReplace : No Op registered for StaticRegexReplace with domain_version of 8\n"
     ]
    }
   ],
   "source": [
    "import onnxruntime as ort\n",
    "\n",
    "try:\n",
    "    sess = ort.InferenceSession(\"model1.onnx\")\n",
    "except Exception as e:\n",
    "    print(e)"
   ]
  },
  {
   "source": [
    "### Enabling custom ops in the converter\n",
    "\n",
    "Fortunately, in this case there is already a custom op implementing the functionality we need: StringRegexReplace.  The converter has a rule to replace TF's StaticRegexReplace op with the StringRegexReplace custom op.  To enable conversions that use custom ops, add the `--extra_opset ai.onnx.contrib:1` flag."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -m tf2onnx.convert --saved-model \"saved_model1\" --output \"model1.onnx\" --extra_opset ai.onnx.contrib:1"
   ]
  },
  {
   "source": [
    "### Loading custom ops into ORT\n",
    "\n",
    "Pass the location of the custom ops library into the ORT session options to use the op."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Inputs: ['input_1:0']\n",
      "Outputs: ['Identity:0']\n"
     ]
    }
   ],
   "source": [
    "import onnxruntime as ort\n",
    "from onnxruntime_extensions import get_library_path\n",
    "\n",
    "so = ort.SessionOptions()\n",
    "so.register_custom_ops_library(get_library_path())\n",
    "\n",
    "sess = ort.InferenceSession(\"model1.onnx\", so)\n",
    "print(\"Inputs:\", [inp.name for inp in sess.get_inputs()])\n",
    "print(\"Outputs:\", [out.name for out in sess.get_outputs()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "[array(['Hello_World!'], dtype=object)]"
      ]
     },
     "metadata": {},
     "execution_count": 9
    }
   ],
   "source": [
    "sess.run([\"Identity:0\"], {\"input_1:0\": [\"Hello World!\"]})"
   ]
  },
  {
   "source": [
    "## Case 2: Defining new custom ops with Python\n",
    "\n",
    "If there is no existing custom op implementation, you will need to define the op yourself and add a conversion rule for it."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model2(tf.keras.Model):\n",
    "\n",
    "    def __init__(self, name='model2', **kwargs):\n",
    "        super(Model2, self).__init__(name=name, **kwargs)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        x, segment_ids = inputs\n",
    "        num_segs = tf.reduce_max(segment_ids) + 1\n",
    "        return tf.strings.unsorted_segment_join(x, segment_ids, num_segs, separator='-')\n",
    "\n",
    "model2 = Model2()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'java-script', b'car-pet'], dtype=object)>"
      ]
     },
     "metadata": {},
     "execution_count": 12
    }
   ],
   "source": [
    "model2([tf.constant([\"car\", \"java\", \"pet\", \"script\"]), tf.constant([1, 0, 1, 0])])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "INFO:tensorflow:Assets written to: saved_model2\\assets\n"
     ]
    }
   ],
   "source": [
    "model2.save(\"saved_model2\", save_format=\"tf\")"
   ]
  },
  {
   "source": [
    "### Adding a custom op conversion rule using the command line\n",
    "\n",
    "We need to tell the converter how to convert the TF DecodeGif op. Even if our custom op will have the same name as the TF op, the node must be tagged with the custom ops domain `ai.onnx.contrib`.\n",
    "\n",
    "Pass `--extra_opset ai.onnx.contrib:1` and `--custom-ops DecodeGif:ai.onnx.contrib` flags to the converter."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -m tf2onnx.convert --saved-model \"saved_model2\" --output \"model2a.onnx\" --extra_opset ai.onnx.contrib:1 --custom-ops UnsortedSegmentJoin:ai.onnx.contrib"
   ]
  },
  {
   "source": [
    "### Adding a custom op conversion rule using python\n",
    "\n",
    "For more complicated conversions, the rule can be defined using python.  See the [tf2onnx repo](https://github.com/onnx/tensorflow-onnx/tree/master/tf2onnx/onnx_opset) for more conversion rule examples."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "source": [
    "import numpy as np\n",
    "from tf2onnx import utils, constants\n",
    "from tf2onnx.handler import tf_op\n",
    "\n",
    "# Registers a conversion rule for UnsortedSegmentJoin op\n",
    "# Rule will only be run if ai.onnx.contrib domain is included via --extra_opset flag\n",
    "@tf_op(\"UnsortedSegmentJoin\", domain=constants.CONTRIB_OPS_DOMAIN)\n",
    "class ConvertUnsortedSegmentJoinOp:\n",
    "    @classmethod\n",
    "    def version_1(cls, ctx, node, **kwargs):\n",
    "        node.type = \"MyCustomStringSegmentJoin\"\n",
    "        # Don't forget to set the domain!\n",
    "        node.domain = constants.CONTRIB_OPS_DOMAIN\n",
    "        # Ops defined using the custom ops framework only get access to inputs, not attributes\n",
    "        separator = node.get_attr_str(\"separator\") if \"separator\" in node.attr else ''\n",
    "        for a in list(node.attr.keys()):\n",
    "            del node.attr[a]\n",
    "        # Add the separator as an additional string input\n",
    "        separator_const = ctx.make_const(utils.make_name('separator_const'), np.array([separator], dtype=np.object))\n",
    "        ctx.replace_inputs(node, node.input + [separator_const.output[0]])"
   ],
   "cell_type": "code",
   "metadata": {},
   "execution_count": 15,
   "outputs": []
  },
  {
   "source": [
    "Next, call the converter using the [tf2onnx Python API](https://github.com/onnx/tensorflow-onnx#python-api-reference). All rules decorated with `@tf_op` will be used."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Inputs: ['inputs:0', 'inputs_1:0']\nOutputs: ['Identity:0']\n"
     ]
    }
   ],
   "source": [
    "concrete_fn2 = tf.function(model2.call).get_concrete_function([tf.TensorSpec([None], tf.string), tf.TensorSpec([None], tf.int32)])\n",
    "input_names = [inp.name for inp in concrete_fn2.inputs]\n",
    "output_names = [out.name for out in concrete_fn2.outputs]\n",
    "print(\"Inputs:\", input_names)\n",
    "print(\"Outputs:\", output_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "WARNING:tensorflow:From c:\\Users\\tomwi\\OneDrive - Microsoft\\ONNX\\tensorflow-onnx\\tf2onnx\\tf_loader.py:493: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.extract_sub_graph`\n",
      "Conversion complete!\n"
     ]
    }
   ],
   "source": [
    "from tf2onnx import tf_loader\n",
    "from tf2onnx.tfonnx import process_tf_graph\n",
    "from tf2onnx.optimizer import optimize_graph\n",
    "\n",
    "graph_def = tf_loader.from_function(concrete_fn2, input_names=input_names, output_names=output_names)\n",
    "extra_opset = [utils.make_opsetid(constants.CONTRIB_OPS_DOMAIN, 1)]\n",
    "with tf.Graph().as_default() as tf_graph:\n",
    "    tf.import_graph_def(graph_def, name='')\n",
    "with tf_loader.tf_session(graph=tf_graph):\n",
    "    g = process_tf_graph(tf_graph, input_names=input_names, output_names=output_names, extra_opset=extra_opset)\n",
    "onnx_graph = optimize_graph(g)\n",
    "model_proto = onnx_graph.make_model(\"converted\")\n",
    "utils.save_protobuf(\"model2b.onnx\", model_proto)\n",
    "print(\"Conversion complete!\")"
   ]
  },
  {
   "source": [
    "### Implementing the op in python\n",
    "\n",
    "Add a function with the `@onnx_op` decorator to register a custom op before creating the ORT InferenceSession.  The inputs will be passed in as numpy arrays, and a numpy array of the declared type should be returned.  \n",
    "\n",
    "**NOTE:** ORT only will allow an op to be registered once, so you must restart the Jupyter kernel each time you change the implementation below."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from onnxruntime_extensions import onnx_op, PyCustomOpDef\n",
    "\n",
    "@onnx_op(op_type=\"UnsortedSegmentJoin\",\n",
    "         inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_int32, PyCustomOpDef.dt_int32],\n",
    "         outputs=[PyCustomOpDef.dt_string])\n",
    "def unsorted_segment_join(x, segment_ids, num_segments):\n",
    "    # The custom op implementation.\n",
    "    result = np.full([num_segments], '', dtype=np.object)\n",
    "    for s, seg_id in zip(x, segment_ids):\n",
    "        result[seg_id] += s\n",
    "    return result\n",
    "\n",
    "@onnx_op(op_type=\"MyCustomStringSegmentJoin\",\n",
    "         inputs=[PyCustomOpDef.dt_string, PyCustomOpDef.dt_int32, PyCustomOpDef.dt_int32, PyCustomOpDef.dt_string],\n",
    "         outputs=[PyCustomOpDef.dt_string])\n",
    "def string_segment_join(x, segment_ids, num_segments, separator):\n",
    "    result = [[] for i in range(num_segments)]\n",
    "    separator = separator[0]\n",
    "    for s, seg_id in zip(x, segment_ids):\n",
    "        result[seg_id].append(s)\n",
    "    result_joined = [separator.join(l) for l in result]\n",
    "    return np.array(result_joined, dtype=np.object)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[array(['javascript', 'carpet'], dtype=object)]\n[array(['java-script', 'car-pet'], dtype=object)]\n"
     ]
    }
   ],
   "source": [
    "import onnxruntime as ort\n",
    "from onnxruntime_extensions import get_library_path\n",
    "\n",
    "so = ort.SessionOptions()\n",
    "so.register_custom_ops_library(get_library_path())\n",
    "\n",
    "sess = ort.InferenceSession(\"model2a.onnx\", so)\n",
    "# Use the input names from the saved_model_cli\n",
    "print(sess.run([\"Identity:0\"], {\"input_1:0\": [\"car\", \"java\", \"pet\", \"script\"], \"input_2:0\": [1, 0, 1, 0]}))\n",
    "\n",
    "sess = ort.InferenceSession(\"model2b.onnx\", so)\n",
    "# Use the input names from the concrete function\n",
    "print(sess.run([\"Identity:0\"], {input_names[0]: [\"car\", \"java\", \"pet\", \"script\"], input_names[1]: [1, 0, 1, 0]}))"
   ]
  },
  {
   "source": [
    "## Case 3: Implementing custom ops in C++\n",
    "\n",
    "Add a conversion rule for your custom op using the instructions in the previous section.  It can be useful to prototype the op in python before developing a C++ version.  Follow the [C++ Custom Ops Tutorial](https://github.com/microsoft/ort-customops/blob/main/tutorials/cpp_custom_ops_tutorial.md) to create a C++ version of the op."
   ],
   "cell_type": "markdown",
   "metadata": {}
  }
 ]
}